{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "677867a3",
   "metadata": {},
   "outputs": [],
   "source": [
    "from langchain.chains import GraphCypherQAChain\n",
    "from langchain.chat_models import ChatOpenAI\n",
    "from langchain.graphs import Neo4jGraph\n",
    "from langchain.callbacks import get_openai_callback\n",
    "from dotenv import load_dotenv\n",
    "import os\n",
    "import openai\n",
    "import pandas as pd\n",
    "from neo4j.exceptions import CypherSyntaxError\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "186eb11d",
   "metadata": {},
   "source": [
    "## Choose the LLM"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "fe18f3fc",
   "metadata": {},
   "outputs": [],
   "source": [
    "LLM_MODEL = 'gpt-4-32k'\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1a040044",
   "metadata": {},
   "source": [
    "## Custom function for neo4j RAG chain"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "ead42cd3",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_neo4j_cypher_rag_chain(llm_model=None, temperature=0):\n",
    "    \"\"\"Build a GraphCypherQAChain over a Neo4j graph, backed by an Azure OpenAI chat model.\n",
    "\n",
    "    Connection settings are read from environment variables after loading two\n",
    "    dotenv files from the user's home directory:\n",
    "      * ~/.spoke_neo4j_config.env -> NEO4J_URI, NEO4J_USER, NEO4J_PSW, NEO4J_DB\n",
    "      * ~/.gpt_config.env         -> API_KEY, API_VERSION, RESOURCE_ENDPOINT\n",
    "\n",
    "    Args:\n",
    "        llm_model: Deployment name handed to the chat model as `engine`.\n",
    "            When None (the default) the notebook-level LLM_MODEL is used.\n",
    "        temperature: Temperature passed to the chat model (default 0).\n",
    "\n",
    "    Returns:\n",
    "        The chain produced by GraphCypherQAChain.from_llm, created with\n",
    "        verbose=True, validate_cypher=True and return_intermediate_steps=True.\n",
    "\n",
    "    Note:\n",
    "        As a side effect this sets the module-level attributes api_type,\n",
    "        api_key, api_base and api_version on the `openai` module.\n",
    "    \"\"\"\n",
    "    home_dir = os.path.expanduser('~')\n",
    "\n",
    "    # Neo4j connection settings.\n",
    "    load_dotenv(os.path.join(home_dir, '.spoke_neo4j_config.env'))\n",
    "    graph = Neo4jGraph(\n",
    "        url=os.environ.get('NEO4J_URI'),\n",
    "        username=os.environ.get('NEO4J_USER'),\n",
    "        password=os.environ.get('NEO4J_PSW'),\n",
    "        database=os.environ.get('NEO4J_DB'),\n",
    "    )\n",
    "\n",
    "    # OpenAI settings; the openai module is configured globally for the \"azure\" API type.\n",
    "    load_dotenv(os.path.join(home_dir, '.gpt_config.env'))\n",
    "    api_key = os.environ.get('API_KEY')\n",
    "    openai.api_type = \"azure\"\n",
    "    openai.api_key = api_key\n",
    "    openai.api_base = os.environ.get('RESOURCE_ENDPOINT')\n",
    "    openai.api_version = os.environ.get('API_VERSION')\n",
    "\n",
    "    # Resolve the default at call time so a changed LLM_MODEL is picked up.\n",
    "    if llm_model is None:\n",
    "        llm_model = LLM_MODEL\n",
    "\n",
    "    # `engine` is not a declared ChatOpenAI parameter: LangChain moves it into\n",
    "    # model_kwargs and emits a warning when the model is constructed.\n",
    "    chat_model = ChatOpenAI(\n",
    "        openai_api_key=api_key,\n",
    "        engine=llm_model,\n",
    "        temperature=temperature,\n",
    "    )\n",
    "\n",
    "    return GraphCypherQAChain.from_llm(\n",
    "        chat_model,\n",
    "        graph=graph,\n",
    "        verbose=True,\n",
    "        validate_cypher=True,\n",
    "        return_intermediate_steps=True,\n",
    "    )"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cc863aed",
   "metadata": {},
   "source": [
    "## Initialize the Neo4j RAG chain"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "9866fd11",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING! engine is not default parameter.\n",
      "                    engine was transferred to model_kwargs.\n",
      "                    Please confirm that engine is what you intended.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 17.5 ms, sys: 3.65 ms, total: 21.2 ms\n",
      "Wall time: 28.1 s\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "neo4j_rag_chain = get_neo4j_cypher_rag_chain()\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3827d5c4",
   "metadata": {},
   "source": [
    "## Enter question 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "7b9837c5",
   "metadata": {},
   "outputs": [],
   "source": [
    "question = \"Is Parkinson's disease associated with PINK1 gene?\"\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "23e669ed",
   "metadata": {},
   "source": [
    "## Run cypher RAG"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "9f685b8e",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "\n",
      "\u001b[1m> Entering new GraphCypherQAChain chain...\u001b[0m\n",
      "Generated Cypher:\n",
      "\u001b[32;1m\u001b[1;3mMATCH (d:Disease {name: \"Parkinson's disease\"}), (g:Gene {name: \"PINK1\"}) \n",
      "RETURN EXISTS((d)-[:ASSOCIATES_DaG]->(g)) AS is_associated\u001b[0m\n",
      "Full Context:\n",
      "\u001b[32;1m\u001b[1;3m[{'is_associated': True}]\u001b[0m\n",
      "\n",
      "\u001b[1m> Finished chain.\u001b[0m\n",
      "Yes, Parkinson's disease is associated with the PINK1 gene.\n"
     ]
    }
   ],
   "source": [
    "\n",
    "out = neo4j_rag_chain.run(query=question, return_final_only=True, verbose=True)\n",
    "print(out)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "069bc2f2",
   "metadata": {},
   "source": [
    "## Question 1 perturbed (all lowercase)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "4cf3363b",
   "metadata": {},
   "outputs": [],
   "source": [
    "question = \"Is parkinson's disease associated with pink1 gene?\"\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "487fd17f",
   "metadata": {},
   "source": [
    "## Run cypher RAG"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "61d5eb0b",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "\n",
      "\u001b[1m> Entering new GraphCypherQAChain chain...\u001b[0m\n",
      "Generated Cypher:\n",
      "\u001b[32;1m\u001b[1;3mMATCH (d:Disease {name: \"Parkinson's Disease\"}), (g:Gene {name: \"PINK1\"}) \n",
      "RETURN EXISTS((d)-[:ASSOCIATES_DaG]->(g)) AS is_associated\u001b[0m\n",
      "Full Context:\n",
      "\u001b[32;1m\u001b[1;3m[]\u001b[0m\n",
      "\n",
      "\u001b[1m> Finished chain.\u001b[0m\n",
      "I'm sorry, but I don't have the information to answer that question.\n"
     ]
    }
   ],
   "source": [
    "\n",
    "out = neo4j_rag_chain.run(query=question, return_final_only=True, verbose=True)\n",
    "print(out)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6123c1a7",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
